{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8d5b4237-b10c-4ebb-8295-6492b59a4e03",
   "metadata": {},
   "source": [
    "# Omnivore: A Single Model for Many Visual Modalities\n",
    "\n",
    "Omnivore is a classification model, introduced in [this paper](https://arxiv.org/abs/2201.08377), that is able to accept different visual modalities as input: standard images (RGB), videos, or depth images (RGBD). The model uses the Video Swin Transformer as its encoder and has multiple heads, one for each visual modality.\n",
    "\n",
    "In this notebook, we demonstrate how to use the Omnivore model, loaded with its pretrained weights, to classify an input from each of the visual modalities.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "246d92a5-b27f-423f-a26b-2808a05b89f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torchvision.io import read_video, _read_video_from_file\n",
    "import torchvision.transforms as T\n",
    "\n",
    "import torchmultimodal.models.omnivore as omnivore\n",
    "from examples.omnivore.data import presets\n",
    "\n",
    "from IPython.display import Video\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "404529a3-c132-4412-b6e3-3191fcdba70d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build Omnivore (swin_t variant) with its pretrained weights and switch it\n",
    "# to evaluation mode; `eval()` returns the module itself, so it can be chained\n",
    "model = omnivore.omnivore_swin_t(pretrained=True).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1d34f9bf-41d4-4b03-ad84-52a37a7d3c61",
   "metadata": {},
   "source": [
    "## Inference on Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3dcd21d-b04f-4c29-9521-0e6f2bd149b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the sample image and the ImageNet class list into assets/\n",
    "# -nc (no-clobber) skips files that are already present, so re-running this\n",
    "# cell does not fetch them again or leave numbered duplicates behind\n",
    "os.makedirs(\"assets\", exist_ok=True)\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/imagenet_val_ringlet_butterfly_001.JPEG\" -P \"assets/\"\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/imagenet_class.json\" -P \"assets/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "018c166b-9d3a-45db-b880-85a182d713b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Open the sample image, then load the ImageNet class list\n",
    "pil_img = Image.open(\"assets/imagenet_val_ringlet_butterfly_001.JPEG\")\n",
    "with open(\"assets/imagenet_class.json\", \"r\") as class_file:\n",
    "    imagenet_classes = json.load(class_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96c92d8d-06a8-4b2b-b2be-1cb2ff0b4260",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the input image (the notebook renders the cell's last expression)\n",
    "pil_img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c4bc450-203e-4928-b72b-60895b52128c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess the PIL image with the evaluation preset, then prepend a\n",
    "# batch dimension so the model receives a batched input\n",
    "img_val_presets = presets.ImageNetClassificationPresetEval(crop_size=224)\n",
    "input_img = img_val_presets(pil_img).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9e4cdb8-4967-4735-b3f3-f2e349e23466",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get top5 labels\n",
    "# Gradients are not needed for inference, so run the forward pass under\n",
    "# no_grad to avoid building the autograd graph\n",
    "with torch.no_grad():\n",
    "    preds = model(input_img, \"image\")\n",
    "\n",
    "# preds[0] holds the scores of the single image in the batch\n",
    "top5_values, top5_indices = preds[0].topk(5)\n",
    "top5_labels = [imagenet_classes[index] for index in top5_indices.tolist()]\n",
    "top5_labels\n",
    "\n",
    "# The correct label is a ringlet butterfly, see: https://en.wikipedia.org/wiki/Ringlet"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0b9e19b-6cd3-4e04-a850-5400abdd9151",
   "metadata": {},
   "source": [
    "## Inference on Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2503590f-b497-4809-a38c-cdaca53d54fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the sample video and the Kinetics-400 class list into assets/\n",
    "# -nc (no-clobber) skips files that are already present, so re-running this\n",
    "# cell does not fetch them again or leave numbered duplicates behind\n",
    "os.makedirs(\"assets\", exist_ok=True)\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/kinetics400_val_snowboarding_001.mp4\" -P \"assets/\"\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/kinetics400_class.json\" -P \"assets/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0031161e-af63-4357-9cc2-812f6a75e585",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read class list and video\n",
    "with open(\"assets/kinetics400_class.json\", \"r\") as fin:\n",
    "    kinetics400_classes = json.load(fin)\n",
    "\n",
    "# pts_unit='sec' is the unit torchvision recommends (the default 'pts' emits a\n",
    "# warning); no start/end is given, so the whole file is read either way\n",
    "video, audio, info = read_video(\n",
    "    \"assets/kinetics400_val_snowboarding_001.mp4\", pts_unit=\"sec\", output_format=\"TCHW\"\n",
    ")\n",
    "\n",
    "# Training clips were sampled at 16 fps and the input video is 30 fps, so we\n",
    "# keep every 2nd frame to get 15 fps, which is close to the training fps\n",
    "video = video[::2]\n",
    "\n",
    "# Use first 50 frames\n",
    "video = video[:50]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f002359-4c55-48cc-b762-12ec54da1256",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play the original (unsampled) video file inline, 512 pixels wide\n",
    "Video(data=\"assets/kinetics400_val_snowboarding_001.mp4\", width=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "308ebbb4-3a94-4955-be8a-dd6e0f9ed389",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the video evaluation preset on the sampled frames, then prepend a\n",
    "# batch dimension so the model receives a batched input\n",
    "video_val_presets = presets.VideoClassificationPresetEval(crop_size=224, resize_size=224)\n",
    "input_video = video_val_presets(video).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78f4c48a-4e67-482d-b184-e02493c66b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get top5 labels\n",
    "# Gradients are not needed for inference, so run the forward pass under\n",
    "# no_grad to avoid building the autograd graph\n",
    "with torch.no_grad():\n",
    "    preds = model(input_video, \"video\")\n",
    "\n",
    "# preds[0] holds the scores of the single video in the batch\n",
    "top5_values, top5_indices = preds[0].topk(5)\n",
    "top5_labels = [kinetics400_classes[index] for index in top5_indices.tolist()]\n",
    "top5_labels\n",
    "\n",
    "# The correct label is snowboarding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f013c367-68c2-4b81-927e-867019952663",
   "metadata": {},
   "source": [
    "## Inference on Depth Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b8268f6-c5ff-4a38-a03c-b03d80b851b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the depth image, its RGB image and intrinsics, and the SUN RGB-D\n",
    "# class list into assets/\n",
    "# -nc (no-clobber) skips files that are already present, so re-running this\n",
    "# cell does not fetch them again or leave numbered duplicates behind\n",
    "os.makedirs(\"assets\", exist_ok=True)\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_depth_001.png\" -P \"assets/\"\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_image_001.jpg\" -P \"assets/\"\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_intrinsics_001.txt\" -P \"assets/\"\n",
    "!wget -nc \"https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_class.json\" -P \"assets/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "387e7848-bc6d-4f53-9388-f76372d59371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SUN RGB-D class list used to name the predicted indices\n",
    "with open(\"assets/sunrgbd_class.json\", \"r\") as class_file:\n",
    "    sunrgbd_classes = json.load(class_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89de4233-4fa6-40c7-99ef-e58683c16896",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the RGBD input: the RGB channels followed by a disparity channel\n",
    "\n",
    "# The focal length is the first value on the first line of the intrinsics file\n",
    "with open(\"assets/sunrgbd_val_kitchen_intrinsics_001.txt\", \"r\") as fin:\n",
    "    focal_length = float(fin.readline().split()[0])\n",
    "\n",
    "# Baseline of the kv2 sensor of sunrgbd (where this depth image comes from)\n",
    "baseline = 0.075\n",
    "\n",
    "_to_tensor = T.ToTensor()\n",
    "\n",
    "img_rgb = Image.open(\"assets/sunrgbd_val_kitchen_image_001.jpg\")\n",
    "tensor_rgb = _to_tensor(img_rgb)\n",
    "\n",
    "img_depth = Image.open(\"assets/sunrgbd_val_kitchen_depth_001.png\")\n",
    "tensor_depth = _to_tensor(img_depth)\n",
    "\n",
    "# Convert depth to disparity: baseline * focal_length / (depth / 1000)\n",
    "# (presumably the depth is stored in millimeters - TODO confirm)\n",
    "# NOTE(review): a depth value of 0 yields an infinite disparity here;\n",
    "# confirm that the depth preset applied later clamps it\n",
    "scaled_depth = tensor_depth / 1000.0\n",
    "tensor_disparity = baseline * focal_length / scaled_depth\n",
    "\n",
    "# Append the disparity after the RGB channels along dim 0\n",
    "tensor_rgbd = torch.cat((tensor_rgb, tensor_disparity), dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a2830b-7abb-4803-9be8-1abc3b1770c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the RGB image and its depth image side by side\n",
    "fig, (ax_rgb, ax_depth) = plt.subplots(1, 2, figsize=(16, 16))\n",
    "\n",
    "ax_rgb.imshow(np.asarray(img_rgb))\n",
    "ax_rgb.set_title(\"RGB image\")\n",
    "\n",
    "ax_depth.imshow(np.asarray(img_depth), cmap=\"jet\")\n",
    "# The trailing semicolon keeps the Text repr out of the cell output\n",
    "ax_depth.set_title(\"Depth image\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "403921be-d0ea-4fc1-bc5c-c5f0c1083b6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the depth evaluation preset on the RGBD tensor, then prepend a\n",
    "# batch dimension so the model receives a batched input\n",
    "depth_val_presets = presets.DepthClassificationPresetEval(crop_size=224, resize_size=224)\n",
    "input_depth = depth_val_presets(tensor_rgbd).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54648603-d427-4506-b146-29dd65f62fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get top5 predictions\n",
    "# Gradients are not needed for inference, so run the forward pass under\n",
    "# no_grad to avoid building the autograd graph\n",
    "with torch.no_grad():\n",
    "    preds = model(input_depth, \"rgbd\")\n",
    "\n",
    "# preds[0] holds the scores of the single RGBD input in the batch\n",
    "top5_values, top5_indices = preds[0].topk(5)\n",
    "top5_labels = [sunrgbd_classes[index] for index in top5_indices.tolist()]\n",
    "top5_labels\n",
    "\n",
    "# The correct label is kitchen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b283aa1-ce7e-403d-bd16-138c36d0846d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
